{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "import torch\n",
    "\n",
    "from src.utils import AttrDict\n",
    "from src.envs import build_env\n",
    "from src.model import build_modules\n",
    "\n",
    "from src.utils import to_cuda\n",
    "from src.envs.sympy_utils import simplify"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build environment / Reload model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the trained model checkpoint. Download it first, e.g.\n",
    "#   wget https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd_bwd.pth\n",
    "model_path = './fwd_bwd.pth'\n",
    "\n",
    "# fail early, and say which file is missing, instead of erroring later during model reload\n",
    "assert os.path.isfile(model_path), f'model checkpoint not found: {model_path}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings handed to build_env(params) and build_modules(env, params) in the next cells.\n",
    "params = AttrDict({\n",
    "\n",
    "    # environment parameters\n",
    "    'env_name': 'char_sp',\n",
    "    'int_base': 10,\n",
    "    'balanced': False,\n",
    "    'positive': True,\n",
    "    'precision': 10,\n",
    "    'n_variables': 1,\n",
    "    'n_coefficients': 0,\n",
    "    'leaf_probs': '0.75,0,0.25,0',\n",
    "    'max_len': 512,\n",
    "    'max_int': 5,\n",
    "    'max_ops': 15,\n",
    "    'max_ops_G': 15,\n",
    "    'clean_prefix_expr': True,\n",
    "    'rewrite_functions': '',\n",
    "    'tasks': 'prim_fwd',\n",
    "    'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',\n",
    "\n",
    "    # model parameters\n",
    "    'cpu': False,\n",
    "    'emb_dim': 1024,\n",
    "    'n_enc_layers': 6,\n",
    "    'n_dec_layers': 6,\n",
    "    'n_heads': 8,\n",
    "    'dropout': 0,\n",
    "    'attention_dropout': 0,\n",
    "    'sinusoidal_embeddings': False,\n",
    "    'share_inout_emb': True,\n",
    "    'reload_model': model_path,  # checkpoint file whose existence was asserted above\n",
    "\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the environment from the parameters defined above\n",
    "env = build_env(params)\n",
    "\n",
    "# the environment's object registered under the name 'x';\n",
    "# it is used below as the differentiation variable\n",
    "x = env.local_dict['x']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the model components and keep handles on the two used in this notebook\n",
    "modules = build_modules(env, params)\n",
    "encoder, decoder = modules['encoder'], modules['decoder']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start from a function F, compute its derivative f = F', and try to recover F from f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# here you can modify the integral function the model has to predict, F\n",
    "# candidate expressions in infix form; edit the list or the index below to try another one\n",
    "F_candidates = [\n",
    "    'x * tan(exp(x)/x)',\n",
    "    'x * cos(x**2) * tan(x)',\n",
    "    'cos(x**2 * exp(x * cos(x)))',\n",
    "    'ln(cos(x + exp(x)) * sin(x**2 + 2) * exp(x) / x)',\n",
    "]\n",
    "\n",
    "# the last candidate is the active one\n",
    "F_infix = F_candidates[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\log{\\left(\\frac{e^{x} \\sin{\\left(x^{2} + 2 \\right)} \\cos{\\left(x + e^{x} \\right)}}{x} \\right)}$"
      ],
      "text/plain": [
       "log(exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# parse the infix string into the expression F (the integral the model will try to predict),\n",
    "# resolving names through the environment's local dictionary\n",
    "F = sp.sympify(F_infix, locals=env.local_dict)\n",
    "F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\frac{x \\left(2 e^{x} \\cos{\\left(x + e^{x} \\right)} \\cos{\\left(x^{2} + 2 \\right)} - \\frac{\\left(e^{x} + 1\\right) e^{x} \\sin{\\left(x + e^{x} \\right)} \\sin{\\left(x^{2} + 2 \\right)}}{x} + \\frac{e^{x} \\sin{\\left(x^{2} + 2 \\right)} \\cos{\\left(x + e^{x} \\right)}}{x} - \\frac{e^{x} \\sin{\\left(x^{2} + 2 \\right)} \\cos{\\left(x + e^{x} \\right)}}{x^{2}}\\right) e^{- x}}{\\sin{\\left(x^{2} + 2 \\right)} \\cos{\\left(x + e^{x} \\right)}}$"
      ],
      "text/plain": [
       "x*(2*exp(x)*cos(x + exp(x))*cos(x**2 + 2) - (exp(x) + 1)*exp(x)*sin(x + exp(x))*sin(x**2 + 2)/x + exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x - exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x**2)*exp(-x)/(sin(x**2 + 2)*cos(x + exp(x)))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# differentiate F with respect to x to get f = F', which the model will take as input\n",
    "f = sp.diff(F, x)\n",
    "f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute prefix representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F prefix: ['ln', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2']\n",
      "f prefix: ['mul', 'x', 'mul', 'pow', 'cos', 'add', 'x', 'exp', 'x', 'INT-', '1', 'mul', 'pow', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'INT-', '1', 'mul', 'add', 'mul', 'INT+', '2', 'mul', 'cos', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'exp', 'x', 'add', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'add', 'mul', 'INT-', '1', 'mul', 'pow', 'x', 'INT-', '2', 'mul', 'cos', 'add', 'x', 'exp', 'x', 'mul', 'exp', 'x', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'mul', 'INT-', '1', 'mul', 'pow', 'x', 'INT-', '1', 'mul', 'add', 'INT+', '1', 'exp', 'x', 'mul', 'exp', 'x', 'mul', 'sin', 'add', 'INT+', '2', 'pow', 'x', 'INT+', '2', 'sin', 'add', 'x', 'exp', 'x', 'exp', 'mul', 'INT-', '1', 'x']\n"
     ]
    }
   ],
   "source": [
    "# convert both expressions to their prefix token lists and show them\n",
    "F_prefix, f_prefix = env.sympy_to_prefix(F), env.sympy_to_prefix(f)\n",
    "print(f\"F prefix: {F_prefix}\")\n",
    "print(f\"f prefix: {f_prefix}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encode input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepend the task header tokens to f's prefix form and pass the result through env.clean_prefix\n",
    "x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)\n",
    "\n",
    "# map tokens to ids, wrap the sequence in EOS markers, and shape it as a single column\n",
    "token_ids = [env.eos_index, *(env.word2id[w] for w in x1_prefix), env.eos_index]\n",
    "x1 = torch.LongTensor(token_ids).view(-1, 1)\n",
    "len1 = torch.LongTensor([x1.size(0)])\n",
    "x1, len1 = to_cuda(x1, len1)\n",
    "\n",
    "# run the encoder without tracking gradients, then swap its first two dimensions\n",
    "with torch.no_grad():\n",
    "    enc_out = encoder('fwd', x=x1, lengths=len1, causal=False)\n",
    "    encoded = enc_out.transpose(0, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decode with beam search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of hypotheses kept by the beam search\n",
    "beam_size = 10\n",
    "\n",
    "with torch.no_grad():\n",
    "    _, _, beam = decoder.generate_beam(\n",
    "        encoded,\n",
    "        len1,\n",
    "        beam_size=beam_size,\n",
    "        length_penalty=1.0,\n",
    "        early_stopping=1,\n",
    "        max_len=200,\n",
    "    )\n",
    "\n",
    "# a single input sequence was encoded, so exactly one beam is expected\n",
    "assert len(beam) == 1\n",
    "\n",
    "# the beam must hold one hypothesis per beam slot\n",
    "hypotheses = beam[0].hyp\n",
    "assert len(hypotheses) == beam_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Print results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input function f: x*(2*exp(x)*cos(x + exp(x))*cos(x**2 + 2) - (exp(x) + 1)*exp(x)*sin(x + exp(x))*sin(x**2 + 2)/x + exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x - exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x**2)*exp(-x)/(sin(x**2 + 2)*cos(x + exp(x)))\n",
      "Reference function F: log(exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x)\n",
      "\n",
      "-0.00003  OK  log(exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x)\n",
      "-0.28475  OK  log(exp(x)*sin((x**3 + 2*x)/x)*cos(x + exp(x))/x)\n",
      "-0.28592  OK  log(exp(x)*sin(x*(x + 2/x))*cos(x + exp(x))/x)\n",
      "-0.35794  OK  log(exp(x)*sin(x*(x + 1) - x + 2)*cos(x + exp(x))/x)\n",
      "-0.37952  NO  log(exp(x)*sin(x**2*(x + 2/x))*cos(x + exp(x))/x)\n",
      "-0.38034  NO  log(exp(x)*sin(x**2 + 2)*cos(x + sinh(x) + cosh(x))/x)\n",
      "-0.39518  OK  atan(tan(log(exp(x)*sin(x**2 + 2)*cos(x + exp(x))/x)))\n",
      "-0.39689  OK  log(exp(x)*sin(x*(x - 1) + x + 2)*cos(x + exp(x))/x)\n",
      "-0.43203  NO  log(exp(x)*sin((x**2 + 2)**2)*cos(x + exp(x))/x)\n",
      "-0.44538  NO  log(exp(x)*sin(x**2 + 2*x)*cos(x + exp(x))/x)\n"
     ]
    }
   ],
   "source": [
    "print(f\"Input function f: {f}\")\n",
    "print(f\"Reference function F: {F}\")\n",
    "print(\"\")\n",
    "\n",
    "# walk through the hypotheses from highest to lowest score\n",
    "for score, sent in sorted(hypotheses, key=lambda h: h[0], reverse=True):\n",
    "\n",
    "    # parse decoded hypothesis\n",
    "    ids = sent[1:].tolist()                  # decoded token IDs (first position skipped)\n",
    "    tok = [env.id2word[wid] for wid in ids]  # convert to prefix\n",
    "\n",
    "    try:\n",
    "        hyp = env.prefix_to_infix(tok)       # convert to infix\n",
    "        hyp = env.infix_to_sympy(hyp)        # convert to SymPy\n",
    "\n",
    "        # check whether we recover f if we differentiate the hypothesis\n",
    "        # note that sometimes, SymPy fails to show that hyp' - f == 0, and the result is considered as invalid, although it may be correct\n",
    "        res = \"OK\" if simplify(hyp.diff(x) - f, seconds=1) == 0 else \"NO\"\n",
    "\n",
    "    except (KeyboardInterrupt, SystemExit):\n",
    "        # never swallow a kernel interrupt or an exit request: let the user stop the loop\n",
    "        raise\n",
    "\n",
    "    except BaseException:\n",
    "        # any other failure while parsing or checking the hypothesis is reported, not raised\n",
    "        # NOTE(review): kept as broad as a bare except; narrow to Exception once it is\n",
    "        # confirmed that the helpers called above raise nothing outside of it\n",
    "        res = \"INVALID PREFIX EXPRESSION\"\n",
    "        hyp = tok\n",
    "\n",
    "    # print result\n",
    "    print(\"%.5f  %s  %s\" % (score, res, hyp))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
